import json
import time

from fastapi import Depends, Request

from app.models.system import Users
from app.utils.custom_exc import CustomException
from app.utils.security import verify_access_token
from app.utils.utils import get_request_ip


class IPThrottle:
    """
    Per-IP request rate limiter.

    The limit comes from the ``rate`` class attribute, written as
    ``"<count>/<period>"``. Only the first character of the period is
    inspected: ``s``, ``m``, ``h`` or ``d``.

    Usage:
    class MyThrottle(IPThrottle):
        rate = "1/m"

    ip=Depends(MyThrottle.allow_request)

    """
    rate: str = "1/s"

    @staticmethod
    def parse_rate(rate):
        """
        Convert a rate string into ``(allowed_requests, window_seconds)``.

        Raises ``ValueError`` when the string does not split into exactly
        two parts on ``/`` or the count is not an integer, and ``KeyError``
        when the period does not start with ``s``, ``m``, ``h`` or ``d``.
        """
        count_text, unit = rate.split('/')
        allowed = int(count_text)
        seconds_per_unit = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
        return allowed, seconds_per_unit[unit[0]]

    @classmethod
    async def allow_request(cls, request: Request):
        """
        Record the current request for the caller's IP and path, or reject it.

        Returns the client IP when the request is allowed. Raises
        ``CustomException`` when the IP is reported as ``"unknown"`` or when
        the number of requests inside the window has reached the limit.
        """
        client_ip = get_request_ip(request)
        # Requests from the loopback address are never throttled.
        if client_ip == "127.0.0.1":
            return client_ip
        # Presumably the sentinel get_request_ip yields when it cannot
        # determine the address -- verify against that helper.
        if client_ip == "unknown":
            raise CustomException("不支持的IP")

        store = request.app.state.redis
        route = request.url.path
        # One history per (ip, path) pair.
        cache_key = f'throttle_ip_{client_ip}_{route}'

        limit, window = cls.parse_rate(cls.rate)

        cached = await store.get(cache_key)
        timestamps = json.loads(cached) if cached else []
        now = time.time()
        cutoff = now - window
        # Timestamps are kept newest-first, so expired ones sit at the tail.
        while timestamps and timestamps[-1] <= cutoff:
            del timestamps[-1]
        if len(timestamps) >= limit:
            raise CustomException("请求太快，请稍后...")
        # The third positional argument is presumably the expiry in seconds
        # (depends on the redis client in use -- confirm).
        await store.set(cache_key, json.dumps([now, *timestamps]), window)  # type: ignore
        return client_ip


class UserThrottle:
    """
        Per-user request rate limiter.

        The limit comes from the ``rate`` class attribute, written as
        ``"<count>/<period>"``. Only the first character of the period is
        inspected: ``s``, ``m``, ``h`` or ``d``.

        Usage:
        class MyThrottle(UserThrottle):
            rate = "1/m"

        user=Depends(MyThrottle.allow_request)
    """
    rate: str = "1/s"

    @staticmethod
    def parse_rate(rate):
        """
        Convert a rate string into ``(allowed_requests, window_seconds)``.

        Raises ``ValueError`` when the string does not split into exactly
        two parts on ``/`` or the count is not an integer, and ``KeyError``
        when the period does not start with ``s``, ``m``, ``h`` or ``d``.
        """
        count_text, unit = rate.split('/')
        allowed = int(count_text)
        seconds_per_unit = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}
        return allowed, seconds_per_unit[unit[0]]

    @classmethod
    async def allow_request(cls, request: Request, user: Users = Depends(verify_access_token)):
        """
        Record the current request for this user and path, or reject it.

        ``user`` is supplied by the ``verify_access_token`` dependency.
        Returns that same user object when the request is allowed. Raises
        ``CustomException`` when the number of requests inside the window
        has reached the limit.
        """
        store = request.app.state.redis
        route = request.url.path
        # One history per (user id, path) pair.
        cache_key = f'throttle_user_{user.id}_{route}'

        limit, window = cls.parse_rate(cls.rate)

        cached = await store.get(cache_key)
        timestamps = json.loads(cached) if cached else []
        now = time.time()
        cutoff = now - window
        # Timestamps are kept newest-first, so expired ones sit at the tail.
        while timestamps and timestamps[-1] <= cutoff:
            del timestamps[-1]
        if len(timestamps) >= limit:
            raise CustomException("请求太快，请稍后...")
        # The third positional argument is presumably the expiry in seconds
        # (depends on the redis client in use -- confirm).
        await store.set(cache_key, json.dumps([now, *timestamps]), window)  # type: ignore
        return user


def ip_rate_throttle(t_rate: str, request: Request):
    """
    Synchronous per-IP throttle check for a single request.

    ``t_rate`` is written as ``"<count>/<period>"``; only the first character
    of the period is inspected (``s``, ``m``, ``h`` or ``d``).

    Returns ``None`` for requests from ``127.0.0.1`` (never throttled) and
    ``True`` when the request is allowed and recorded. Raises
    ``CustomException`` when the IP is reported as ``"unknown"`` or when the
    number of requests inside the window has reached the limit.

    NOTE(review): unlike the throttle classes in this file, this function
    reads the store from ``request.state`` rather than ``request.app.state``,
    calls it without ``await``, passes a default to ``get`` and stores a raw
    list. Confirm the object behind ``request.state.redis`` supports that.
    """
    client_ip = get_request_ip(request)
    if client_ip == "127.0.0.1":
        return
    # Presumably the sentinel get_request_ip yields when it cannot determine
    # the address -- verify against that helper.
    if client_ip == "unknown":
        raise CustomException("不支持")

    store = request.state.redis
    route = request.url.path
    # One history per (ip, path) pair.
    cache_key = f'throttle_{client_ip}_{route}'

    # Parse the rate string into a request budget and a window in seconds.
    count_text, unit = t_rate.split('/')
    limit = int(count_text)
    window = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[unit[0]]

    history = store.get(cache_key, [])   # type: ignore
    now = time.time()
    cutoff = now - window
    # Timestamps are kept newest-first, so expired ones sit at the tail.
    while history and history[-1] <= cutoff:
        history.pop()
    if len(history) >= limit:
        raise CustomException("请求太快，请稍后...")
    # Mutate the fetched object in place, exactly as stored back below.
    history.insert(0, now)
    store.set(cache_key, history, window)  # type: ignore
    return True

